{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6fa6fb7f",
   "metadata": {},
   "source": [
    "# Naive RAG\n",
    "\n",
    "**절차**\n",
    "\n",
    "1. Naive RAG 수행\n",
    "\n",
    "![langgraph-naive-rag](assets/langgraph-naive-rag.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f21c872b",
   "metadata": {},
   "source": [
    "## 환경 설정"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c9b2e9a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# !pip install -U langchain-teddynote"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "064d5c8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# API 키를 환경변수로 관리하기 위한 설정 파일\n",
    "from dotenv import load_dotenv\n",
    "\n",
    "# API 키 정보 로드\n",
    "load_dotenv()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "562b0043",
   "metadata": {},
   "outputs": [],
   "source": [
    "# LangSmith 추적을 설정합니다. https://smith.langchain.com\n",
    "# !pip install -qU langchain-teddynote\n",
    "from langchain_teddynote import logging\n",
    "\n",
    "# 프로젝트 이름을 입력합니다.\n",
    "logging.langsmith(\"CH17-LangGraph-Structures\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "06468c1c",
   "metadata": {},
   "source": [
    "## 기본 PDF 기반 Retrieval Chain 생성\n",
    "\n",
    "여기서는 PDF 문서를 기반으로 Retrieval Chain 을 생성합니다. 가장 단순한 구조의 Retrieval Chain 입니다.\n",
    "\n",
    "단, LangGraph 에서는 Retirever 와 Chain 을 따로 생성합니다. 그래야 각 노드별로 세부 처리를 할 수 있습니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f905df18",
   "metadata": {},
   "outputs": [],
   "source": [
    "from rag.pdf import PDFRetrievalChain\n",
    "\n",
    "# PDF 문서를 로드합니다.\n",
    "pdf = PDFRetrievalChain([\"data/SPRI_AI_Brief_2023년12월호_F.pdf\"]).create_chain()\n",
    "\n",
    "# retriever와 chain을 생성합니다.\n",
    "pdf_retriever = pdf.retriever\n",
    "pdf_chain = pdf.chain"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fa6f7524",
   "metadata": {},
   "source": [
    "먼저, pdf_retriever 를 사용하여 검색 결과를 가져옵니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d532337",
   "metadata": {},
   "outputs": [],
   "source": [
    "search_result = pdf_retriever.invoke(\"앤스로픽에 투자한 기업과 투자금액을 알려주세요.\")\n",
    "search_result"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f6f95f50",
   "metadata": {},
   "source": [
    "이전에 검색한 결과를 chain 의 context 로 전달합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d957188f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 검색 결과를 기반으로 답변을 생성합니다.\n",
    "answer = pdf_chain.invoke(\n",
    "    {\n",
    "        \"question\": \"앤스로픽에 투자한 기업과 투자금액을 알려주세요.\",\n",
    "        \"context\": search_result,\n",
    "        \"chat_history\": [],\n",
    "    }\n",
    ")\n",
    "print(answer)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d047f938",
   "metadata": {},
   "source": [
    "## State 정의\n",
    "\n",
    "`State`: Graph 의 노드와 노드 간 공유하는 상태를 정의합니다.\n",
    "\n",
    "일반적으로 `TypedDict` 형식을 사용합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "f19a3df5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Annotated, TypedDict\n",
    "from langgraph.graph.message import add_messages\n",
    "\n",
    "\n",
    "# GraphState 상태 정의\n",
    "class GraphState(TypedDict):\n",
    "    question: Annotated[str, \"Question\"]  # 질문\n",
    "    context: Annotated[str, \"Context\"]  # 문서의 검색 결과\n",
    "    answer: Annotated[str, \"Answer\"]  # 답변\n",
    "    messages: Annotated[list, add_messages]  # 메시지(누적되는 list)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c56d4095",
   "metadata": {},
   "source": [
    "## 노드(Node) 정의\n",
    "\n",
    "- `Nodes`: 각 단계를 처리하는 노드입니다. 보통은 Python 함수로 구현합니다. 입력과 출력이 상태(State) 값입니다.\n",
    "  \n",
    "**참고**  \n",
    "\n",
    "- `State`를 입력으로 받아 정의된 로직을 수행한 후 업데이트된 `State`를 반환합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "9ef0c055",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_teddynote.messages import messages_to_history\n",
    "from rag.utils import format_docs\n",
    "\n",
    "\n",
    "# 문서 검색 노드\n",
    "def retrieve_document(state: GraphState) -> GraphState:\n",
    "    # 질문을 상태에서 가져옵니다.\n",
    "    latest_question = state[\"question\"]\n",
    "\n",
    "    # 문서에서 검색하여 관련성 있는 문서를 찾습니다.\n",
    "    retrieved_docs = pdf_retriever.invoke(latest_question)\n",
    "\n",
    "    # 검색된 문서를 형식화합니다.(프롬프트 입력으로 넣어주기 위함)\n",
    "    retrieved_docs = format_docs(retrieved_docs)\n",
    "\n",
    "    # 검색된 문서를 context 키에 저장합니다.\n",
    "    return {\"context\": retrieved_docs}\n",
    "\n",
    "\n",
    "# 답변 생성 노드\n",
    "def llm_answer(state: GraphState) -> GraphState:\n",
    "    # 질문을 상태에서 가져옵니다.\n",
    "    latest_question = state[\"question\"]\n",
    "\n",
    "    # 검색된 문서를 상태에서 가져옵니다.\n",
    "    context = state[\"context\"]\n",
    "\n",
    "    # 체인을 호출하여 답변을 생성합니다.\n",
    "    response = pdf_chain.invoke(\n",
    "        {\n",
    "            \"question\": latest_question,\n",
    "            \"context\": context,\n",
    "            \"chat_history\": messages_to_history(state[\"messages\"]),\n",
    "        }\n",
    "    )\n",
    "    # 생성된 답변, (유저의 질문, 답변) 메시지를 상태에 저장합니다.\n",
    "    return {\n",
    "        \"answer\": response,\n",
    "        \"messages\": [(\"user\", latest_question), (\"assistant\", response)],\n",
    "    }"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a3f7785d",
   "metadata": {},
   "source": [
    "## 그래프 생성\n",
    "\n",
    "- `Edges`: 현재 `State`를 기반으로 다음에 실행할 `Node`를 결정하는 Python 함수.\n",
    "\n",
    "일반 엣지, 조건부 엣지 등이 있습니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "a6015807",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langgraph.graph import END, StateGraph\n",
    "from langgraph.checkpoint.memory import MemorySaver\n",
    "\n",
    "# 그래프 생성\n",
    "workflow = StateGraph(GraphState)\n",
    "\n",
    "# 노드 정의\n",
    "workflow.add_node(\"retrieve\", retrieve_document)\n",
    "workflow.add_node(\"llm_answer\", llm_answer)\n",
    "\n",
    "# 엣지 정의\n",
    "workflow.add_edge(\"retrieve\", \"llm_answer\")  # 검색 -> 답변\n",
    "workflow.add_edge(\"llm_answer\", END)  # 답변 -> 종료\n",
    "\n",
    "# 그래프 진입점 설정\n",
    "workflow.set_entry_point(\"retrieve\")\n",
    "\n",
    "# 체크포인터 설정\n",
    "memory = MemorySaver()\n",
    "\n",
    "# 컴파일\n",
    "app = workflow.compile(checkpointer=memory)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d9a15c32",
   "metadata": {},
   "source": [
    "컴파일한 그래프를 시각화 합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e09251d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_teddynote.graphs import visualize_graph\n",
    "\n",
    "visualize_graph(app)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "110daa16",
   "metadata": {},
   "source": [
    "## 그래프 실행\n",
    "\n",
    "- `config` 파라미터는 그래프 실행 시 필요한 설정 정보를 전달합니다.\n",
    "- `recursion_limit`: 그래프 실행 시 재귀 최대 횟수를 설정합니다.\n",
    "- `inputs`: 그래프 실행 시 필요한 입력 정보를 전달합니다.\n",
    "\n",
    "**참고**\n",
    "\n",
    "- 메시지 출력 스트리밍은 [LangGraph 스트리밍 모드의 모든 것](https://wikidocs.net/265770) 을 참고해주세요.\n",
    "\n",
    "아래의 `stream_graph` 함수는 특정 노드만 스트리밍 출력하는 함수입니다.\n",
    "\n",
    "손쉽게 특정 노드의 스트리밍 출력을 확인할 수 있습니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2698eaa",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.runnables import RunnableConfig\n",
    "from langchain_teddynote.messages import invoke_graph, stream_graph, random_uuid\n",
    "\n",
    "# config 설정(재귀 최대 횟수, thread_id)\n",
    "config = RunnableConfig(recursion_limit=20, configurable={\"thread_id\": random_uuid()})\n",
    "\n",
    "# 질문 입력\n",
    "inputs = GraphState(question=\"앤스로픽에 투자한 기업과 투자금액을 알려주세요.\")\n",
    "\n",
    "# 그래프 실행\n",
    "invoke_graph(app, inputs, config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6a15fa0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 그래프를 스트리밍 출력\n",
    "stream_graph(app, inputs, config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a16ba031",
   "metadata": {},
   "outputs": [],
   "source": [
    "outputs = app.get_state(config).values\n",
    "\n",
    "print(f'Question: {outputs[\"question\"]}')\n",
    "print(\"===\" * 20)\n",
    "print(f'Answer:\\n{outputs[\"answer\"]}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py-test",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
